/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file load_to_l0a_load2dV2.h
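 * \brief LoadToL0A specialization for matrix (GemvMode::MATRIX) A tiles loaded with the LOAD2DV2 instruction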
 */

#ifndef IMPL_MATMUL_STAGE_SPLIT_LOAD_TO_L0A_LOAD2DV2_H
#define IMPL_MATMUL_STAGE_SPLIT_LOAD_TO_L0A_LOAD2DV2_H

#include "load_to_l0a_intf.h"
#include "../load_to_l0_utils.h"

namespace AscendC {
namespace Impl {
namespace Detail {
template <typename IMPL, typename A_TYPE, const auto& MM_CFG>
class LoadToL0A<IMPL, A_TYPE, MM_CFG,
    enable_if_t<(GetGemvMode<A_TYPE>() == GemvMode::MATRIX) &&
                (GetLoadInstrType<typename A_TYPE::T, MM_CFG>() == LoadInstrType::LOAD2DV2)>>
{
    // Specialization used when A is a full matrix (GemvMode::MATRIX) and the load instruction selected
    // for A's element type under MM_CFG is LOAD2DV2. Every tile is moved with LoadData driven by
    // LoadData2DParamsV2; A types with a scale position (MX) additionally load their auxiliary scale
    // tensor through LoadData2DMxParams.
    using A_T = typename A_TYPE::T;
    // Destination element type, chosen via GetL0DataType depending on whether A_TYPE has a scale position.
    using L0A_T = typename Conditional<HasScalePosition<A_TYPE>::value, typename GetL0DataType<typename A_TYPE::T, true>::Type, typename GetL0DataType<typename A_TYPE::T, false>::Type>::type;
    // Element type of the auxiliary (scale) tensor consumed by the MX path.
    using AuxDtype = decltype(GetAuxDataType<A_TYPE>());
public:
    __aicore__ inline LoadToL0A() {}
    __aicore__ inline ~LoadToL0A() {}

    // The LOAD2DV2 path keeps no per-tile state, so Prepare and SetScalar are intentionally no-ops.
    __aicore__ inline void Prepare(bool isATranspose, uint16_t aL1K, uint16_t aL1M) const {}
    __aicore__ inline void SetScalar(A_T scalar) {}

    // Loads a madM x madK tile of A from aMatrix into dst.
    //   dst            destination tensor
    //   aMatrix        source tensor holding A
    //   aL1M, aL1K     source extents; only aL1K is read when isATranspose is true, only aL1M otherwise
    //   madM, madK     extents of the tile to load
    //   aL1MOffset,
    //   aL1KOffset     element offsets of the tile inside aMatrix; they are turned into fractal start
    //                  positions with CeilDiv, so callers presumably pass fractal-aligned offsets (TODO confirm)
    //   isATranspose   selects the transposed source layout
    //   l1AAuxMatrix, aAuxL1K, aAuxL1KOffset
    //                  scale tensor, its K extent and its K offset; only read when A_TYPE has a scale position
    __aicore__ inline void Load(const LocalTensor<L0A_T> &dst, const LocalTensor<A_T> &aMatrix,
        uint16_t aL1M, uint16_t aL1K, uint16_t madM, uint16_t madK, uint16_t aL1MOffset, uint16_t aL1KOffset,
        bool isATranspose, const LocalTensor<AuxDtype> &l1AAuxMatrix = {}, uint16_t aAuxL1K = 0,
        uint16_t aAuxL1KOffset = 0) const
    {
        // Types with a scale position (MX) must take the Mx variants so the scale tensor is loaded as well.
        if (isATranspose) {
            if constexpr (HasScalePosition<A_TYPE>::value) {
                MxTransLoadDataToL0(dst, aMatrix, aL1K, madM, madK, aL1MOffset, aL1KOffset, l1AAuxMatrix,
                    aAuxL1K, aAuxL1KOffset);
            } else {
                TransLoadDataToL0(dst, aMatrix, aL1K, madM, madK, aL1MOffset, aL1KOffset);
            }
        } else {
            if constexpr (HasScalePosition<A_TYPE>::value) {
                MxLoadDataToL0(dst, aMatrix, aL1M, madM, madK, aL1MOffset, aL1KOffset, l1AAuxMatrix,
                    aAuxL1K, aAuxL1KOffset);
            } else {
                LoadDataToL0(dst, aMatrix, aL1M, madM, madK, aL1MOffset, aL1KOffset);
            }
        }
    }
private:
    // C0 size of A_T as reported by AuxGetC0Size; used as the divisor that turns element counts
    // along the c0 direction into fractal counts.
    constexpr static int32_t c0Size_ = AuxGetC0Size<A_T>();

    // Transposed load. In the load2dv2 instruction the m direction walks A's K axis and the k direction
    // walks A's M axis, so the m*/k* parameters below are fed from K/M quantities respectively.
    __aicore__ inline void TransLoadDataToL0(const LocalTensor<A_T> &dst, const LocalTensor<A_T> &aMatrix,
        uint16_t aL1K, uint16_t madM, uint16_t madK, uint16_t aL1MOffset, uint16_t aL1KOffset) const
    {
        LoadData2DParamsV2 loadDataParams;
        loadDataParams.mStartPosition = CeilDiv(aL1KOffset, BLOCK_CUBE);
        loadDataParams.kStartPosition = CeilDiv(aL1MOffset, c0Size_);
        loadDataParams.kStep = CeilDiv(madM, c0Size_);
        if constexpr (IsSameType<A_T, float>::value) {
            // K step must be a multiple of 2 when transpose is enabled and .type = .b32
            loadDataParams.kStep = CeilAlign(loadDataParams.kStep, K_STEP_MIN_VAL_B32);
        }
        loadDataParams.srcStride = CeilDiv(aL1K, ALIGN_NUM);
        loadDataParams.dstStride = CeilDiv(madM, ALIGN_NUM);
        loadDataParams.ifTranspose = true;
        loadDataParams.mStep = CeilDiv(madK, HW_M0);
        if constexpr (IsSupportB4<A_T>()) {
            // M step must be a multiple of 4 when transpose is enabled and .type = .b4
            loadDataParams.mStep = CeilAlign(loadDataParams.mStep, M_STEP_MIN_VAL_B4);
        }

        if constexpr (IsSupportB8<A_T>()) {
            // M step must be a multiple of 2 when transpose is enabled and .type = .b8, so the tile is issued
            // in chunks of M_STEP_MIN_VAL_B8 along the m (= A's K) direction; each chunk is written
            // dstAddrStride further into dst than the previous one.
            uint16_t l0ALoop = CeilAlign(loadDataParams.mStep, M_STEP_MIN_VAL_B8) / M_STEP_MIN_VAL_B8;
            uint64_t dstOffset = 0;
            uint64_t dstAddrStride = CeilAlign(madM, ALIGN_NUM) * ONE_BLK_SIZE;
            loadDataParams.mStep = M_STEP_MIN_VAL_B8;
            uint16_t oriMstartPos = loadDataParams.mStartPosition;
            for (uint16_t idx = 0; idx < l0ALoop; ++idx) {
                loadDataParams.mStartPosition = oriMstartPos + M_STEP_MIN_VAL_B8 * idx;
                LoadData(dst[dstOffset], aMatrix, loadDataParams);
                dstOffset += dstAddrStride;
            }
        } else if constexpr (IsSameType<A_T, float>::value) {
            // In case of MDL && baseK == 8 the unit of mStartPosition is 16, so the start positions are not
            // used; the tile offset is applied to aMatrix directly instead.
            loadDataParams.mStartPosition = 0;
            loadDataParams.kStartPosition = 0;
            // Widen before multiplying so the uint16_t operands are not promoted to signed int.
            uint64_t matrixOffset = static_cast<uint64_t>(aL1MOffset) * aL1K +
                static_cast<uint64_t>(aL1KOffset) * B32_C0SIZE;
            LoadData(dst, aMatrix[matrixOffset], loadDataParams);
        } else {
            LoadData(dst, aMatrix, loadDataParams);
        }
    }

    // Non-transposed load: the instruction's m/k directions follow A's M/K axes directly.
    __aicore__ inline void LoadDataToL0(const LocalTensor<A_T> &dst, const LocalTensor<A_T> &aMatrix,
        uint16_t aL1M, uint16_t madM, uint16_t madK, uint16_t aL1MOffset, uint16_t aL1KOffset) const
    {
        LoadData2DParamsV2 loadDataParams;
        loadDataParams.mStartPosition = CeilDiv(aL1MOffset, BLOCK_CUBE);
        loadDataParams.kStartPosition = CeilDiv(aL1KOffset, c0Size_);
        loadDataParams.mStep = CeilDiv(madM, HW_M0);
        loadDataParams.kStep = CeilDiv(madK, c0Size_);
        loadDataParams.srcStride = CeilDiv(aL1M, ALIGN_NUM);
        loadDataParams.dstStride = CeilDiv(madM, ALIGN_NUM);
        loadDataParams.ifTranspose = false;
        LoadData(dst, aMatrix, loadDataParams);
    }

    // Transposed MX load: loads the data tile together with its scale tensor. Compiled only for
    // __DAV_C310__; on other targets the body is empty.
    __aicore__ inline void MxTransLoadDataToL0(const LocalTensor<L0A_T> &dst, const LocalTensor<A_T> &aMatrix,
        uint16_t aL1K, uint16_t madM, uint16_t madK, uint16_t aL1MOffset, uint16_t aL1KOffset,
        const LocalTensor<AuxDtype> &l1AAuxMatrix, uint16_t aAuxL1K, uint16_t aAuxL1KOffset) const
    {
#if defined(__DAV_C310__)
        // Data parameters: as in TransLoadDataToL0, the instruction's m direction walks A's K axis and
        // its k direction walks A's M axis.
        LoadData2DParamsV2 loadDataParams;
        loadDataParams.mStartPosition = CeilDiv(aL1KOffset, ALIGN_NUM);
        loadDataParams.kStartPosition = CeilDiv(aL1MOffset, c0Size_);
        loadDataParams.mStep = CeilDiv(madK, HW_M0);
        loadDataParams.kStep = CeilDiv(madM, c0Size_);
        loadDataParams.srcStride = CeilDiv(aL1K, HW_M0);
        loadDataParams.dstStride = CeilDiv(madM, HW_M0);
        loadDataParams.ifTranspose = true;

        // Scale parameters: x follows A's M axis; the y fields come from the scale tensor's K axis.
        LoadData2DMxParams loadDataMxParams;
        loadDataMxParams.xStartPosition = CeilDiv(aL1MOffset, ALIGN_NUM);
        loadDataMxParams.xStep = CeilDiv(madM, HW_M0);
        SetMxScaleParams(loadDataMxParams, madK, aAuxL1K, aAuxL1KOffset);
        LoadData(dst, aMatrix, l1AAuxMatrix, loadDataParams, loadDataMxParams);
#endif
    }

    // Non-transposed MX load: loads the data tile together with its scale tensor. Compiled only for
    // __DAV_C310__; on other targets the body is empty.
    __aicore__ inline void MxLoadDataToL0(const LocalTensor<L0A_T> &dst, const LocalTensor<A_T> &aMatrix,
        uint16_t aL1M, uint16_t madM, uint16_t madK, uint16_t aL1MOffset, uint16_t aL1KOffset,
        const LocalTensor<AuxDtype> &l1AAuxMatrix, uint16_t aAuxL1K, uint16_t aAuxL1KOffset) const
    {
#if defined(__DAV_C310__)
        // The data load's m fields and the scale load's x fields share the same start and step along A's M axis.
        uint16_t mStartPos = CeilDiv(aL1MOffset, BLOCK_CUBE);
        uint16_t mStep = CeilDiv(madM, HW_M0);

        LoadData2DParamsV2 loadDataParams;
        loadDataParams.mStartPosition = mStartPos;
        loadDataParams.kStartPosition = CeilDiv(aL1KOffset, c0Size_);
        loadDataParams.mStep = mStep;
        loadDataParams.kStep = CeilDiv(madK, c0Size_);
        loadDataParams.srcStride = CeilDiv(aL1M, HW_M0);
        loadDataParams.dstStride = CeilDiv(madM, HW_M0);

        LoadData2DMxParams loadDataMxParams;
        loadDataMxParams.xStartPosition = mStartPos;
        loadDataMxParams.xStep = mStep;
        SetMxScaleParams(loadDataMxParams, madK, aAuxL1K, aAuxL1KOffset);
        LoadData(dst, aMatrix, l1AAuxMatrix, loadDataParams, loadDataMxParams);
#endif
    }

#if defined(__DAV_C310__)
    // Fills the y-direction (scale K axis) fields of loadDataMxParams; shared by both MX load paths.
    // For FP4 one scale step is taken per c0Size_ of K; for FP8 two K0 on the k axis correspond to one
    // small z fractal, so steps and strides are counted per c0Size_ * FP8_TWO of K.
    // Other element types leave the y fields untouched.
    __aicore__ inline void SetMxScaleParams(LoadData2DMxParams &loadDataMxParams, uint16_t madK,
        uint16_t aAuxL1K, uint16_t aAuxL1KOffset) const
    {
        if constexpr (SupportType<A_T, float4_e2m1_t, float4_e1m2_t>()) {
            uint16_t scaleKStartPos = CeilDiv(aAuxL1KOffset, FP4_TWO);
            uint16_t scaleKStep = CeilDiv(madK, c0Size_);
            uint16_t srcScaleStride = CeilDiv(aAuxL1K, FP4_TWO);
            uint16_t dstScaleStride = CeilDiv(madK, c0Size_);
            loadDataMxParams.yStartPosition = scaleKStartPos;
            loadDataMxParams.yStep = scaleKStep;
            loadDataMxParams.srcStride = srcScaleStride;
            loadDataMxParams.dstStride = dstScaleStride;
        } else if constexpr (SupportType<A_T, float8_e5m2_t, float8_e4m3_t>()) {
            uint16_t scaleKStartPos = CeilDiv(aAuxL1KOffset, FP8_TWO);
            uint16_t scaleKStep = CeilDiv(madK, c0Size_ * FP8_TWO);
            uint16_t srcScaleStride = CeilDiv(aAuxL1K, FP8_TWO);
            uint16_t dstScaleStride = CeilDiv(madK, c0Size_ * FP8_TWO);
            loadDataMxParams.yStartPosition = scaleKStartPos;
            loadDataMxParams.yStep = scaleKStep;
            loadDataMxParams.srcStride = srcScaleStride;
            loadDataMxParams.dstStride = dstScaleStride;
        }
    }
#endif
};

}  // namespace Detail
}  // namespace Impl
}  // namespace AscendC
#endif // IMPL_MATMUL_STAGE_SPLIT_LOAD_TO_L0A_LOAD2DV2_H